import os
import zipfile
import tarfile

import pytorch_lightning as pl
import requests
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10,CIFAR100
from tqdm import tqdm
from enum import Enum

from sparsezoo.utils import download_file
import random
from typing import Union
from torchvision import transforms
from torchvision.datasets import ImageFolder
import numpy as np

# Notes:
### CIFAR100 pretraining followed the LR and epoch setup from https://github.com/weiaicunzai/pytorch-cifar100
### Imagenette used torchvision's ImageNet-pretrained weights


class Dataset(pl.LightningDataModule):
    """
    Lightning data module that serves the Imagenette2, CIFAR10 or CIFAR100 data.

    The CIFAR datasets only provide training and testing data, so the
    validation set is derived by splitting the testing data 50/50: the first
    half is used for validation and the second half for testing.
    Imagenette2 is not split; its non-training split is returned for both
    validation and testing.

    :param data_dir: root directory the datasets are read from / downloaded to
    :param dataset: name of the dataset to serve. Loaders and datasets are
        implemented for "Imagenette2", "CIFAR10" and "CIFAR100"; any other
        name makes the loader/dataset methods raise ValueError
    """

    def __init__(self, data_dir, dataset):
        super().__init__()
        self.data_dir = data_dir
        self.dataset = dataset
        # Per-channel normalization statistics; these defaults are the ones
        # used for CIFAR10 and for any name not listed below.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)

        if self.dataset == "Imagenette2":
            self.mean = (0.485, 0.456, 0.406)
            self.std = (0.229, 0.224, 0.225)
        elif self.dataset == "CIFAR100":
            self.mean = (0.5071, 0.4867, 0.4408)
            self.std = (0.2675, 0.2565, 0.2761)
        elif self.dataset == "MNIST":
            # Statistics only: no MNIST loader is implemented in this class.
            self.mean = (0.1307,)
            self.std = (0.3081,)

    ### Private helpers shared by the dataloader and dataset accessors

    def _cifar_split(self, train, half=None):
        """
        Build the CIFAR10/CIFAR100 dataset with ToTensor + Normalize applied.

        :param train: True for the training data, False for the testing data
        :param half: None to keep every sample, "first" to keep the first half
            (validation) or "second" to keep the second half (testing)
        :return: the torchvision CIFAR dataset, truncated in place if requested
        """
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(self.mean, self.std),
        ])
        cifar_cls = CIFAR10 if self.dataset == "CIFAR10" else CIFAR100
        data = cifar_cls(
            root=self.data_dir, download=True, train=train, transform=transform
        )

        if half is not None:
            # Integer ceil(len / 2): an odd-sized set gives the extra sample
            # to the first (validation) half.
            data_mid = (len(data.data) + 1) // 2
            targets_mid = (len(data.targets) + 1) // 2
            if half == "first":
                data.data = data.data[:data_mid]
                data.targets = data.targets[:targets_mid]
            else:
                data.data = data.data[data_mid:]
                data.targets = data.targets[targets_mid:]

        return data

    def _make_split(self, split):
        """
        Build the dataset object for one split of the configured dataset.

        :param split: one of "train", "val" or "test"
        :return: the dataset for that split
        :raises ValueError: if no dataset is implemented for ``self.dataset``
        """
        if self.dataset == "Imagenette2":
            # "val" and "test" both map onto the non-training split.
            return ImagenetteDataset(root=self.data_dir, train=(split == "train"))

        if self.dataset == "CIFAR10" or self.dataset == "CIFAR100":
            if split == "train":
                return self._cifar_split(train=True)
            return self._cifar_split(
                train=False, half="first" if split == "val" else "second"
            )

        raise ValueError(
            "unsupported dataset {!r}; expected 'Imagenette2', 'CIFAR10' "
            "or 'CIFAR100'".format(self.dataset)
        )

    def _make_loader(self, data, shuffle):
        """
        Wrap a dataset in a DataLoader using this module's fixed settings.

        Batch size is 32 for Imagenette2 and 512 otherwise. ``drop_last`` is
        always True, so a final partial batch is never yielded, including for
        the validation and test loaders.

        :param data: the dataset to load from
        :param shuffle: whether the loader reshuffles the data every epoch
        :return: the configured DataLoader
        """
        return DataLoader(
            data,
            batch_size=32 if self.dataset == "Imagenette2" else 512,
            num_workers=1,
            shuffle=shuffle,
            drop_last=True,
            pin_memory=True,
        )

    ### Dataloaders

    def train_dataloader(self):
        """
        :return: shuffled DataLoader over the training split
        :raises ValueError: if no loader is implemented for ``self.dataset``
        """
        print("mean and std: ", self.mean," ",self.std)
        return self._make_loader(self._make_split("train"), shuffle=True)

    def val_dataloader(self):
        """
        :return: unshuffled DataLoader over the validation split (first half
            of the CIFAR test data, or the Imagenette2 non-training split)
        :raises ValueError: if no loader is implemented for ``self.dataset``
        """
        return self._make_loader(self._make_split("val"), shuffle=False)

    def test_dataloader(self):
        """
        :return: DataLoader over the testing split (second half of the CIFAR
            test data, or the Imagenette2 non-training split). The Imagenette2
            loader shuffles; the CIFAR loaders do not.
        :raises ValueError: if no loader is implemented for ``self.dataset``
        """
        return self._make_loader(
            self._make_split("test"), shuffle=(self.dataset == "Imagenette2")
        )

    ### Functions to directly return the datasets rather than dataloaders

    def train_dataset(self):
        """
        :return: the training dataset
        :raises ValueError: if no dataset is implemented for ``self.dataset``
        """
        return self._make_split("train")

    def val_dataset(self):
        """
        :return: the validation dataset (first half of the CIFAR test data, or
            the Imagenette2 non-training split)
        :raises ValueError: if no dataset is implemented for ``self.dataset``
        """
        return self._make_split("val")

    def test_dataset(self):
        """
        :return: the testing dataset (second half of the CIFAR test data, or
            the Imagenette2 non-training split)
        :raises ValueError: if no dataset is implemented for ``self.dataset``
        """
        return self._make_split("test")
        
        
        










### Code for downloading and extracting the Imagenette dataset



### From https://github.com/neuralmagic/sparseml/blob/main/src/sparseml/pytorch/datasets/classification/imagenette.py
class ImagenetteDownloader(object):
    """
    Downloader implementation for the imagenette dataset.
    More info on the dataset can be found
    `here <https://github.com/fastai/imagenette>`__
    :param download_root: the local path to download the files to
    :param dataset_size: which dataset size to download; in this class it is
        only stored and echoed in progress messages, the full-size
        imagenette2 archive is always the one fetched
    :param download: True to run the download, False otherwise.
        If False, the archive ``<download_root>/imagenette2.tar`` must
        already exist, otherwise ValueError is raised.
    :param train: stored on the instance as ``train``; not otherwise used here
    """

    def __init__(
        self, download_root: str, dataset_size: int, download: bool, train: bool
    ):
        self._download_root = download_root
        self._dataset_size = dataset_size
        self._download = download
        self.train = train

        print("Download root is: ", download_root)

        self._extract_name = "imagenette2"

        self._extracted_root = os.path.join(self._download_root, self._extract_name)
        print("Extracted root is: ", self._extracted_root)

        if download:
            self._download_and_extract()
        else:
            file_path = "{}.tar".format(self._extracted_root)

            if not os.path.exists(file_path):
                raise ValueError(
                    "could not find original tar for the dataset at {}".format(
                        file_path
                    )
                )
        # Trace marker from the original implementation, kept so console
        # output is unchanged.
        print("cc")

    @property
    def download_root(self) -> str:
        """
        :return: the local path to download the files to
        """
        return self._download_root

    @property
    def dataset_size(self) -> int:
        """
        :return: which dataset size to download
        """
        return self._dataset_size

    @property
    def download(self) -> bool:
        """
        :return: True to run the download, False otherwise.
            If False, dataset must already exist at root.
        """
        return self._download

    @property
    def extracted_root(self) -> str:
        """
        :return: Where the specific dataset was extracted to
        """
        return self._extracted_root

    def split_root(self, train: bool) -> str:
        """
        :param train: True to get the path to the train dataset, False for validation
        :return: The path to the desired split for the dataset
            (the directory is created if it does not exist yet)
        """
        path = os.path.join(self.extracted_root, "train" if train else "val")
        os.makedirs(path,exist_ok=True)

        return path

    def _is_extracted(self) -> bool:
        """
        :return: True if both the train and val folders exist under the
            extracted root and are non-empty
        """
        for split_name in ("train", "val"):
            split_dir = os.path.join(self._extracted_root, split_name)
            # split_root() may have created an empty folder on an earlier
            # run, so existence alone does not prove the data was unpacked.
            if not os.path.isdir(split_dir) or not os.listdir(split_dir):
                return False

        return True

    def _download_and_extract(self):
        """
        Download the imagenette2 archive if it is missing and unpack it into
        the download root.

        An archive that is already on disk is not downloaded again, but it is
        still unpacked when the train/val folders are missing or empty
        (e.g. a manually placed archive or an interrupted extraction).
        """
        url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz"
        print("making directory")
        os.makedirs(self._extracted_root,exist_ok=True)
        # The gzip-compressed archive is stored under a ".tar" name.
        file_path = "{}.tar".format(self._extracted_root)

        if os.path.exists(file_path):
            print("already downloaded imagenette2 {}".format(self._dataset_size))

            if self._is_extracted():
                return
        else:
            print("download_file")

            download_file(
                url,
                file_path,
                overwrite=False,
                progress_title="downloading imagenette2 {}".format(
                    self._dataset_size
                ),
            )
        print("extracting tar")

        with tarfile.open(file_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters are available: refuse members that would
                # escape the destination (absolute paths, "..", bad links).
                tar.extractall(path=self.download_root, filter="data")
            else:
                tar.extractall(path=self.download_root)
           
           
           
           
           
           
           
           
           
           
           
           
class ImagenetteDataset(ImagenetteDownloader, ImageFolder):
    """
    Wrapper for the imagenette (10 class) dataset that fastai created.
    Handles downloading and applying standard transforms.
    :param root: The root folder to find the dataset at,
        if not found will download here if download=True
    :param train: True if this is for the training distribution,
        False for the validation
    :param rand_trans: True to apply RandomResizedCrop and RandomHorizontalFlip
        to the data, False to apply Resize followed by CenterCrop
    :param dataset_size: The size of the dataset to use and download;
        forwarded unchanged to ImagenetteDownloader
    :param image_size: The image size to output from the dataset,
        None to use the default of 224
    :param download: True to download the dataset, False otherwise
    """

    def __init__(
        self,
        root: str = "../data/imagenette2",
        train: bool = True,
        rand_trans: bool = False,
        dataset_size: str = "full",
        image_size: Union[int, None] = None,
        download: bool = False,
    ):
        print("Root is: ", root)
        ImagenetteDownloader.__init__(self, root, dataset_size, download, train)
        # Trace marker from the original implementation, kept so console
        # output is unchanged.
        print("b")

        # Honour an explicit image_size; previously the argument was
        # unconditionally overwritten with 224.
        if image_size is None:
            image_size = 224

        if rand_trans:
            trans = [
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            resize_scale = 256.0 / 224.0  # standard used
            trans = [
                transforms.Resize(round(resize_scale * image_size)),
                transforms.CenterCrop(image_size),
            ]

        trans.extend(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )
        split_dir = self.split_root(train)
        print("Split root is: ", split_dir)

        ImageFolder.__init__(self, split_dir, transforms.Compose(trans))
        print("d")

        self.image_size = image_size
        self.rand_trans = rand_trans

        # make sure we don't preserve the folder structure class order
        random.shuffle(self.samples)
        # `targets` was built from the pre-shuffle sample order; rebuild it so
        # targets[i] is again the label of samples[i].
        self.targets = [label for _, label in self.samples]

